import fire


def is_acceptable_response(content):
    """Decide whether a response string is good enough to stop retrying.

    Content without the ``[ERROR]`` prefix is a normal response and is always
    accepted. ``[ERROR]``-tagged content is accepted only when it mentions
    content filtering, because those errors are tolerated rather than retried.

    Returns:
        True to accept ``content``, False to retry.
    """
    if content.startswith('[ERROR]'):
        # Tolerated error markers coming from content filtering.
        tolerated_markers = ('content_filter', 'content filtering policies')
        return any(marker in content for marker in tolerated_markers)
    return True


def is_content_filtered(content):
    """Return True if ``content`` looks like a content-filtering-policy error.

    The check is a plain substring test: ``content`` must contain both
    ``'error'`` and ``'content filtering policies'``.

    Returns:
        A bool in every case. Previously the function fell off the end and
        returned ``None`` for the negative case.
    """
    return 'error' in content and 'content filtering policies' in content


def getOutputFromGPTV_multi_2(headers, payload, max_retries=None, retry_delay=1):
    """POST a chat request to the GPT-4V endpoint and return the answer text.

    The endpoint URL comes from the ``GPTV_API_BASE`` environment variable.
    The following attempts are retried every ``retry_delay`` seconds:
    request or JSON-parsing failures, bodies starting with ``{"error"``, and
    rate-limit messages. Errors accepted by ``is_acceptable_response``
    (content filtering) end the loop and produce ``""``.

    Args:
        headers: HTTP headers sent with the request.
        payload: JSON-serializable request body.
        max_retries: Maximum number of retries after the first attempt.
            ``None`` (the default) retries forever, as before.
        retry_delay: Seconds to sleep between attempts.

    Returns:
        ``choices[0]["message"]["content"]`` from the response JSON, or
        ``""`` when the request was rejected by content filtering.

    Raises:
        KeyError: If ``GPTV_API_BASE`` is not set.
        TypeError: If ``payload`` is not JSON-serializable.
        RuntimeError: If ``max_retries`` is given and exhausted.
    """
    import os
    import requests
    import json
    import time

    # Resolve deterministic inputs before the retry loop. Raised inside the
    # try below, these errors were tagged "[ERROR]" and retried forever.
    api_base = os.environ["GPTV_API_BASE"]
    body = json.dumps(payload)

    retries = 0
    while True:
        print('\n(trying...) ->', )
        # Reset every attempt so a parse from an earlier, rejected attempt can
        # never be returned for this one.
        responseJson = None
        try:
            response = requests.post(api_base, headers=headers, data=body)
            responseJson = json.loads(response.text)
            content = response.text
        except Exception as e_msg:
            content = '[ERROR] ' + str(e_msg)

        # The service reports failures in the body; tag them so that
        # is_acceptable_response treats them as errors.
        if content.startswith('{"error"'):
            content = '[ERROR] ' + content

        if 'exceeded call rate limit' in content:
            content = '[ERROR] ' + content

        if is_acceptable_response(content):
            break

        if max_retries is not None and retries >= max_retries:
            raise RuntimeError(
                f"GPT-4V request failed after {retries + 1} attempts: {content}")
        retries += 1

        # retry for unacceptable response
        print('\n(retry later...) ->', content)
        time.sleep(retry_delay)

    # Any "[ERROR]" content that left the loop was tolerated by
    # is_acceptable_response, i.e. it is a content-filtering rejection. It has
    # no "choices", so indexing it used to crash with KeyError or NameError.
    # Return an empty string instead.
    if content.startswith('[ERROR]') or is_content_filtered(content):
        return ""

    output = responseJson["choices"][0]["message"]["content"]
    print(output)
    return output


def example(img_path, query, max_tokens=100):
    """Ask GPT-4V ``query`` about the image at ``img_path`` and return the answer.

    The image is read through ``azfuse.File`` and base64-encoded. It is sent
    together with ``query`` in a single user message via
    ``getOutputFromGPTV_multi_2``.

    Args:
        img_path: Path to the image, opened with ``azfuse.File.open``.
        query: Text prompt sent after the image.
        max_tokens: Completion length limit forwarded in the payload.

    Returns:
        The model's answer text, or ``""`` if it was content-filtered.
    """
    import os
    import base64
    from azfuse import File

    # Renamed from ``bytes`` so the builtin is not shadowed.
    with File.open(img_path, "rb") as f:
        image_bytes = f.read()
    encoded_image = base64.b64encode(image_bytes).decode('ascii')

    headers = {
        "Content-Type": "application/json",
        # None when GPT4V_KEY is unset.
        # NOTE(review): confirm how requests handles a None header value here.
        "api-key": os.environ.get("GPT4V_KEY"),
        'cogsvc-openai-gptv-disable-faceblur': 'true',
        # NOTE(review): sent as an HTTP header, not a body field; confirm the
        # endpoint actually honors it this way.
        'logprobs': 'true',
    }

    payload = {
        "messages": [
            {
                "role": "user",
                # The image goes first, followed by the text query.
                "content": [
                    {"image": encoded_image},
                    query,
                ],
            }
        ],
        "max_tokens": max_tokens,
    }
    return getOutputFromGPTV_multi_2(headers, payload)


def inf_on_ours(split="docci_ambguity", image_folder = "<DATA_FOLDER>", limit=100):
    """Run GPT-4V over the first ``limit`` questions of an evaluation split.

    Each answer is cached as ``<output_folder>/<question_id>.json``. Questions
    that already have a cached answer are not re-queried, so an interrupted
    run can be resumed. When the loop finishes, two files are written to the
    split's output folder:

    - ``merge.jsonl``: all answers.
    - ``gt.jsonl``: the questions that were processed.

    Args:
        split: Key into the split table below. Note that "docci_ambguity" is
            spelled that way on purpose: it is the existing key.
        image_folder: Prefix joined with each item's ``"image"`` field.
        limit: Maximum number of questions to process, default 100. ``None``
            processes all of them.

    Raises:
        KeyError: For an unknown ``split``.
    """
    from azfuse import File
    import os
    import json

    def write_jsonl(path, items):
        # One JSON object per line.
        with File.open(path, "w") as f:
            f.write("".join(json.dumps(item) + "\n" for item in items))

    split2data = {
        "docci_ambguity": "<DATA_FOLDER>docci/docci_ambiguity_dev.eval.jsonl",
        "docci_pred": "<DATA_FOLDER>docci/docci_pred_dev.eval.jsonl",
        "docci_complex": "<DATA_FOLDER>docci/docci_complex_dev.eval.jsonl",
        "docci_know": "<DATA_FOLDER>docci/docci_know_dev.eval.jsonl",
        "unk_vqa_validated": "<DATA_FOLDER>/vqav2/vqa_k_test_noun_dedup_sampled_1_sft_llaval_idk.human_valid_rewrite.eval.jsonl"
    }
    output_folder = f"<DATA_FOLDER>/gpt4v_output/{split}/"

    data_file = split2data[split]
    with File.open(data_file) as f:
        # Skip blank lines (e.g. a trailing newline), which json.loads rejects.
        data = [json.loads(line) for line in f.readlines() if line.strip()]

    if "docci" not in split:
        # For non-DOCCI splits, drop items whose image path mentions "gqa".
        data = [d for d in data if "gqa" not in d["image"]]
    data = data[:limit]

    all_outputs = []
    for d in data:
        img_path = os.path.join(image_folder, d["image"])
        qid = d["question_id"]
        output_file = os.path.join(output_folder, f"{qid}.json")
        if File.isfile(output_file):
            # Resume: reuse the answer cached by an earlier run.
            with File.open(output_file, "r") as f:
                all_outputs.append(json.loads(f.read()))
            continue
        output_text = example(img_path, query=d["text"])
        ans_item = {
            "question_id": qid,
            "prompt": d["text"],
            # An f-string, so integer question ids no longer raise TypeError.
            "answer_id": f"gpt4v_{qid}",
            "model_id": "gpt4v",
            "text": output_text,
            "metadata": {},
        }
        with File.open(output_file, "w") as f:
            f.write(json.dumps(ans_item))
        all_outputs.append(ans_item)

    write_jsonl(os.path.join(output_folder, "merge.jsonl"), all_outputs)
    write_jsonl(os.path.join(output_folder, "gt.jsonl"), data)


if __name__ == '__main__':
    # Expose this module's top-level functions as CLI commands, e.g.
    #   python <this_file>.py inf_on_ours --split=docci_pred
    # Fire() with no component uses the caller's globals/locals. Keep the call
    # at module level so every function defined above stays reachable.
    fire.Fire()